{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "b72b1d2b",
   "metadata": {},
   "source": [
    "### 代码实现"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "6ba3f730",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 导入必要的库，torchinfo用于查看模型结构\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "from torchinfo import summary"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "257bddce",
   "metadata": {},
   "source": [
    "### 模型定义"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "d87c584f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 定义LeNet的网络结构\n",
    "class LeNet(nn.Module):\n",
    "    def __init__(self, num_classes=10):\n",
    "        super(LeNet, self).__init__()\n",
    "        # 卷积层1：输入1个通道，输出6个通道，卷积核大小为5x5\n",
    "        self.conv1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5)\n",
    "        # 卷积层2：输入6个通道，输出16个通道，卷积核大小为5x5\n",
    "        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5)\n",
    "        # 全连接层1：输入16x4x4=256个节点，输出120个节点，由于输入数据略有差异，修改为16x4x4\n",
    "        self.fc1 = nn.Linear(in_features=16 * 4 * 4, out_features=120)\n",
    "        # 全连接层2：输入120个节点，输出84个节点\n",
    "        self.fc2 = nn.Linear(in_features=120, out_features=84)\n",
    "        # 输出层：输入84个节点，输出10个节点\n",
    "        self.fc3 = nn.Linear(in_features=84, out_features=num_classes)\n",
    "\n",
    "    def forward(self, x):\n",
    "        # 使用ReLU激活函数，并进行最大池化\n",
    "        x = torch.relu(self.conv1(x))\n",
    "        x = nn.functional.max_pool2d(x, kernel_size=2)\n",
    "        # 使用ReLU激活函数，并进行最大池化\n",
    "        x = torch.relu(self.conv2(x))\n",
    "        x = nn.functional.max_pool2d(x, kernel_size=2)\n",
    "        # 将多维张量展平为一维张量\n",
    "        x = x.view(-1, 16 * 4 * 4)\n",
    "        # 全连接层\n",
    "        x = torch.relu(self.fc1(x))\n",
    "        # 全连接层\n",
    "        x = torch.relu(self.fc2(x))\n",
    "        # 全连接层\n",
    "        x = self.fc3(x)\n",
    "        return x"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c0a6dbeb",
   "metadata": {},
   "source": [
    "### 网络结构"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "42692239",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "==========================================================================================\n",
       "Layer (type:depth-idx)                   Output Shape              Param #\n",
       "==========================================================================================\n",
       "LeNet                                    [1, 10]                   --\n",
       "├─Conv2d: 1-1                            [1, 6, 24, 24]            156\n",
       "├─Conv2d: 1-2                            [1, 16, 8, 8]             2,416\n",
       "├─Linear: 1-3                            [1, 120]                  30,840\n",
       "├─Linear: 1-4                            [1, 84]                   10,164\n",
       "├─Linear: 1-5                            [1, 10]                   850\n",
       "==========================================================================================\n",
       "Total params: 44,426\n",
       "Trainable params: 44,426\n",
       "Non-trainable params: 0\n",
       "Total mult-adds (M): 0.29\n",
       "==========================================================================================\n",
       "Input size (MB): 0.00\n",
       "Forward/backward pass size (MB): 0.04\n",
       "Params size (MB): 0.18\n",
       "Estimated Total Size (MB): 0.22\n",
       "=========================================================================================="
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 查看模型结构及参数量，input_size表示示例输入数据的维度信息\n",
    "summary(LeNet(), input_size=(1, 1, 28, 28))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "25bbfb1c",
   "metadata": {},
   "source": [
    "### 模型训练"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "04cbc625",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 0 Loss: 2.7325645021239664 Acc: 0.2633\n",
      "Epoch: 2 Loss: 2.630008887238046 Acc: 0.6901  \n",
      "Epoch: 4 Loss: 1.9096679044736495 Acc: 0.9047 \n",
      "Epoch: 6 Loss: 1.7179356540642037 Acc: 0.9424 \n",
      "Epoch: 8 Loss: 1.5851480201856594 Acc: 0.9413 \n",
      "100%|██████████| 10/10 [01:46<00:00, 10.65s/it]\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXQAAAD4CAYAAAD8Zh1EAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8+yak3AAAACXBIWXMAAAsTAAALEwEAmpwYAAAkNElEQVR4nO3deXxV9Z3/8dcnyU1CVrKxhABhk7KGJSyiImhdsYpah2rVgmsrtFbbOra1v06XmXbqtNOpOioqVWq1i2Jrp1qnKg5oUQkIIgKyKglLQkJWSMjy/f1xLiHBQBJIcnJv3s/H4z7uveece+4nR/Pmm+/5nu8x5xwiIhL6IvwuQEREOoYCXUQkTCjQRUTChAJdRCRMKNBFRMJElF9fnJ6e7rKzs/36ehGRkLRmzZoDzrmMltb5FujZ2dnk5eX59fUiIiHJzD4+0Tp1uYiIhAkFuohImFCgi4iECd/60EUkvNXW1pKfn091dbXfpYSk2NhYsrKyCAQCbf6MAl1EOkV+fj6JiYlkZ2djZn6XE1KccxQXF5Ofn8+QIUPa/Dl1uYhIp6iuriYtLU1hfgrMjLS0tHb/daNAF5FOozA/dady7EKuy+Wj/RX8z/o9mBkRZkQYREQYZhx7bxZcf2xZs+2tyfYR7dy+yfqkXlGM7JtIVKT+XRQR/4VcoG/dX8kDy7fRXaZxj4uOZNKgFKZkpzIlO4UJg3oTFx1yh1UkLCUkJFBZWel3GV0m5JJnzvj+zBk/B+cczkGDczQEn4+995a5Jusajt++oZ3bO3dsfYP3XFRZw5pdJazedZBfvvYRzkFkhDE2M4ncYMBPHpxKRmKM34dNRHqAkAv0o+xoNwj+9tFdnpMJQHl1Le99UsrqnSWs3lXC029/zBNv7gRgSHo8uYNTmDIklSnZqWSnxalvUaQLOee45557ePnllzEz7rvvPubNm8fevXuZN28e5eXl1NXV8fDDDzNjxgxuvvlm8vLyMDNuuukm7rrrLr9/hDYJ2UDvbpJiA5x7RgbnnuHNmXOkroEP9pSRF2zBv7ppP39ckw9AekI0uYNTyc32umpGZyYRUD+8hLEf/GUjH+4p79B9js5M4vufG9OmbZctW8a6detYv349Bw4cYMqUKcycOZNnnnmGiy66iO9+97vU19dz6NAh1q1bR0FBAR988AEApaWlHVp3Z1Kgd5LoqAgmDUph0qAUbpvptRC2F1WxepfXgs/bdZC/bdwHQK9AJBMH9SY3O5Wp2alMHNSb+Bj9pxHpKG+++SbXXnstkZGR9O3bl3PPPZfVq1czZcoUbrrpJmpra5k7dy4TJkxg6NCh7Nixg69+9avMmTOHCy+80O/y20yp0UXMjOF9EhjeJ4Frpw4CYH95NXm7DjaG/IOvb6Uh2A8/un9SYws+NzuFPomxPv8EIqeurS3prjZz5kxWrFjBX//6V+bPn8/dd9/NjTfeyPr163nllVd45JFH+MMf/sCSJUv8LrVNFOg+6psUGzzJ2x+AimA//NFummff/YRfv7ULgMFpceQO9k60ThmSytD0ePXDi7TROeecw6OPPsqXvvQlSkpKWLFiBffffz8ff/wxWVlZ3HrrrdTU1LB27VouvfRSoqOjufrqqxk5ciTXX3+93+W3mQK9G0mMDTDzjAxmBvvha+sb2LinnLxdJby7s4TlWwp5fq3XD58aH+2daM1O5aIx/RiUFudn6SLd2pVXXsmqVavIycnBzPjZz35Gv379eOqpp7j//vsJBAIkJCSwdOlSCgoKWLBgAQ0NDQD85Cc/8bn6tjPn04Du3NxcpxtctI9zjh0Hqhpb8Hm7SthVfIjkXgFeXHQWg9Pi/S5RpNGmTZsYNWqU32WEtJaOoZmtcc7ltrS9WughxMwYlpHAsIwE5k3x+uG3FVbw+UdWcevSPJbdcRYJOpkq0mNprFyIG94nkYeum8T2oiru+v06Ghq6ySW0ItLlFOhh4Kzh6dw3ZxR//3A///nqR36XIyI+0d/nYWL+jGw27S3ngde38Zl+SY0jZ0Sk51ALPUyYGT+aO5bJg1P45h/Xs3FPmd8liUgXU6CHkZioSB6+fhK94wLctnQNBypr/C5JRLqQAj3M9EmMZfENuRyorOGOp9dypK7B75JEpIu0GuhmNtDMlpvZh2a20czubGGbWWZWZmbrgo//1znlSluMy0rmZ58fz7u7SviXv2z0uxyRsFdXV+d3CUDbWuh1wDecc6OB6cBCMxvdwnYrnXMTgo8fdmiV0m5XTBjAl88dxjPvfMJv3v7Y73JEfDN37lwmT57MmDFjWLx4MQB/+9vfmDRpEjk5OZx//vkAVFZWsmDBAsaNG8f48eN5/vnnAe8mGUc999xzzJ8/H4D58+fz5S9/mWnTpnHPPffw7rvvcuaZZzJx4kRmzJjBli1bAKivr+eb3/wmY8eOZfz48TzwwAO8/vrrzJ07t3G/f//737nyyitP+2dtdZSLc24vsDf4usLMNgEDgA9P+9ulU33ropF8tL+CH7y4kRF9Epg+NM3vkqSnevle2LehY/fZbxxc8tNWN1uyZAmpqakcPnyYKVOmcMUVV3DrrbeyYsUKhgwZQklJCQA/+tGPSE5OZsMGr86DBw+2uu/8/Hz+8Y9/EBkZSXl5OStXriQqKopXX32V73znOzz//PMsXryYXbt2sW7dOqKioigpKSElJYU77riDoqIiMjIy+PWvf81NN910eseDdvahm1k2MBF4p4XVZ5rZejN72cxanFrNzG4zszwzyysqKmp/tdIukRHGL78wgcFpcdzx27XsLjnkd0kiXe5Xv/oVOTk5TJ8+nd27d7N48WJmzpzJkCFDAEhNTQXg1VdfZeHChY2fS0lJaXXf11xzDZGRkQCUlZVxzTXXMHbsWO666y42btzYuN/bb7+dqKioxu8zM2644QaefvppSktLWbVqFZdccslp/6xtHoduZgnA88DXnXPHz1S/FhjsnKs0s0uBPwEjjt+Hc24xsBi8uVxOtWhpu6TYAI/dmMsVD73FrUvzeP4rMzTXunS9NrSkO8Mbb7zBq6++yqpVq4iLi2PWrFlMmDCBzZs3t3kfTWc1ra6ubrYuPv7Y/Enf+973mD17Ni+88AK7du1i1qxZJ93vggUL+NznPkdsbCzXXHNNY+Cfjja10M0sgBfmv3XOLTt+vXOu3DlXGXz9EhAws/TTrk46xNCMBB68bhIf7a/gm39cr+kBpMcoKysjJSWFuLg4Nm/ezNtvv011dTUrVqxg507vFpFHu1wuuOACHnroocbPHu1y6du3L5s2baKhoYEXXnjhpN81YMAAAJ588snG5RdccAGPPvpo44nTo9+XmZlJZmYmP/7xj1mwYEGH/LxtGeViwBPAJufcL06wTb/gdpjZ1OB+izukQukQ556RwbcvGcXLH+zjweXb/C5HpEtcfPHF1NXVMWrUKO69916mT59ORkYGixcv5qqrriInJ4d58+YBcN9993Hw4EHGjh1LTk4Oy5cvB+CnP/0pl112GTNmzKB//xNfgX3PPffw7W9/m4kTJzYb9XLLLbcwaNAgxo8fT05ODs8880zjui9+8YsMHDiww2albHX6XDM7G1gJbACODmr+DjAIwDn3iJktAr6CNyLmMHC3c+4fJ9uvps/tes45vvGH9Sx7r4BHb5jMRWP6+V2ShDFNn9u6RYsWMXHiRG6++eYW13f49LnOuTeBk94axzn3IPBga/sSf5kZ/3bVOLYfqOLu369j2R1nMbJfot9lifRIkydPJj4+np///Ocdtk9dKdrDxAYiWXzDZOJjorhl6WoOVh3xuySRHmnNmjWsWLGCmJiYDtunAr0H6psUy6M3TGZ/eQ0Ln1lLbb2mB5DO4dcd0cLBqRw7BXoPNXFQCj+5chz/2F7Mv/51k9/lSBiKjY2luLhYoX4KnHMUFxcTGxvbrs9pQHIPdvXkLDbtLefxN3cyqn9i423tRDpCVlYW+fn56CLCUxMbG0tWVla7PqNA7+HuveQzbNlfwX1/+oBhGQnkZqf6XZKEiUAg0Hg1pnQNdbn0cFGRETx47SSyUuL48tNr2FN62O+SROQUKdCF5LgAj904meraBm77TR6Hj9T7XZKInAIFugAwvE8iv7p2Ahv3lHPP8+/rRJZICFKgS6PzPtOXb100kr+s38PD/7fd73JEpJ0U6NLMV84dxuU5mdz/yhZe27Tf73JEpB0U6NKMmfHvV49nTGYSd/5uHdsKK/wuSUTaSIEun9IrOpLFN+QSG4jklqfyKDtU63dJItIGCnRpUWbvXjxy/SQKSg+z6Nm11Gl6AJFuT4EuJ5SbncqP545l5dYD/PTltt/hRUT8oStF5aTmTRnEpr0VwekBkrh6cvsuRRaRrqMWurTqu3NGMWNYGt9+YQPvfdL6ndBFxB8KdGlVIDKCh66bRL+kWG7/zRr2l1e3/iER6XIKdGmTlPhoHrsxl6qaOm77zRqqazU9gEh3o0CXNhvZL5H/nDeB9btL+c6yDZoeQKSbUaBLu1w4ph93X3AGy94r4PGVO/0uR0SaUKBLu331vOFcOq4fP3l5E//3kW5eINJdKNCl3cyM/7gmh5H9klj0zFp2FFX6XZKIoECXUxQXHcVjN04mEBnBLUvzKK/W9AAiflOgyynLSonj4S9O4pPiQ9z57HvUN+gkqYifFOhyWqYNTeNfLh/D8i1F3P6bNWzaW+53SSI9li79l9N2/fTBVFTX8dDybVzyXyv57Kg+LJw9nImDUvwuTaRHMb/GEufm5rq8vDxfvls6R9mhWp5atYslb+2k9FAtZw1PY+Hs4Zw5NA0z87s8kbBgZmucc7ktrlOgS0erqqnjmXc+YfHKHRRV1DBpUG8WnTec2SP7KNhFTpMCXXxRXVvPH9fk88gb2ykoPcyo/kksnD2MS8b2JzJCwS5yKhTo4qva+gb+vG4P//3GNnYUVTE0I56vnDuMuRMHEIjUeXmR9lCgS7dQ3+D42wf7eHD5NjbtLWdA7158+dyhXJM7kNhApN/liYSEkwV6q80jMxtoZsvN7EMz22hmd7awjZnZr8xsm5m9b2aTOqJwCS+REcac8f156Wtns2R+Ln2TYvjenzdyzs+Ws3jFdqpq6vwuUSSktdpCN7P+QH/n3FozSwTWAHOdcx822eZS4KvApcA04L+cc9NOtl+10MU5x6odxTy0fBtvbSumd1yABTOGMH9GNslxAb/LE+mWTtZCb3UcunNuL7A3+LrCzDYBA4APm2x2BbDUef86vG1mvc2sf/CzIi0yM2YMS2fGsHTWfnKQ/16+jf989SMeW7mD66cP5uazh5CRGON3mSIho11npMwsG5gIvHPcqgHA7ibv84PLjv/8bWaWZ2Z5RUWapU+OmTQohce/NIWX7zyHWSMzeHTFds7+99f5lxc3sqf0sN/liYSENge6mSUAzwNfd86d0vXdzrnFzrlc51xuRkbGqexCwtyo/kk8eN0kXrv7XC7PyeTptz/m3PuX88/Pvc+uA1V+lyfSrbUp0M0sgBfmv3XOLWthkwJgYJP3WcFlIqdkaEYC91+TwxvfmsW1UwfxwroCzvv5G3zt2ffYsq/C7/JEuqW2jHIx4Algk3PuFyfY7EXgxuBol+lAmfrPpSNkpcTxwyvG8uY/z+bWc4by2qb9XPTLFdy6NI/1u0v9Lk+kW2nLKJezgZXABqAhuPg7wCAA59wjwdB/ELgYOAQscM6ddAiLRrnIqSg9dIRfv7WLJ/+xi7LDtZwzIp2Fs4czbUiqphWQHkEXFknYqayp4+m3P+bxlTs5UFlD7uAUFs4ezqyRGQp2CWsKdAlb1bX1/H71bh79v+3sKatmTGYSV0zIZPrQNEb3TyJKUwtImFGgS9g7UtfAn9YV8MTKnWzZ7500TYyJYsqQVKYPTVXAS9g4rQuLREJBdFQE/5Q7kH/KHUhheTVv7yzh7R3FvL2jmNc3FwKQEBPFlOwUpg9NY/rQNMZkKuAlvCjQJez0SYrl8pxMLs/JBKCwvJp3mgT88i3eRW0KeAk3CnQJe32SYvlcTiafOxrwFdW8s6PlgM9tEvBjFfASYtSHLj1eYUU17za24EvYVlgJQHx0ZLAPXgEv3YdOioq0Q1FFDe/sLD5hwE8bksb0oamMG5CsgJcup0AXOQ1FFTVNWvDFbG0S8LnZR1vwqYwdkKw7MEmnU6CLdKCmAf/OzmI+2t884KcN9VrxYzKTdCcm6XAKdJFOdKCyeQv+aMAHIo2R/RIZN6A347OSGZ+VzBl9E9WKl9OiQBfpQgcqa8jbVcL7+WXBRynl1d7t9aKjIhjdP4mcrGTGZXlBPywjgcgITVcgbaNAF/GRc45PSg6xPr+MDfmlvJ9fxgcFZVQdqQcgLjqSsZnJjAu24sdn9WZwahwRCnlpga4UFfGRmTE4LZ7BafGNFzs1NDh2HKhs1op/+u2PqanzJjRNjI1i3AAv3MdnJTNuQDJZKb008ZiclAJdxAcREcbwPokM75PIVZOyAKirb2BrYSXvB1vxGwrKeOLNHdTWe39Fp8ZHB0PeC/icgb3pmxTr548h3Yy6XES6sZq6erbsq2hsxb+fX8bWwkrqG7zf2z6JMY3dNOOykhk/IJm0BN1YO5ypy0UkRMVERQa7XXoDgwE4fKSeD/d6XTUb8stYn1/Ka5sLOdo2G9C7l9eKD7bkz+ibSJ/EGHXX9AAKdJEQ0ys6ksmDU5k8OLVxWUV1LRv3lDfrrnn5g32N6xNjoxjRJ4ERfRIZ0TeBEX0TGdEngf7JsQr6MKIuF5EwVXroCB/uKWdrYSVbCyvYur+SbYWVFFcdadwmPjqS4cFwH9EnwQv7PokM6N1Lo2y6KQ1bFJFGxZU1bCusZGthZfDZC/vCiprGbXoFIhkeDPnhwZAf0SeBgalxGjPvM/Whi0ijtIQY0hJimDY0rdnyskO1XrgXVrJ1vxf0q3YUs+y9gsZtoqMiGJaR0LxF3zeRwalxmqisG1CgiwgAyXEBcrNTyc1Obba8orq2eYt+fwVrPznIi+v3NG4TiDSGph9tzR/rq89Oiyc6SkHfVRToInJSibEBJg5KYeKglGbLq2rq2F50tDVfybbCCj4oKOOlDXsbR9xERhjZaXEM75PA0IwEhqTHMywjniHpCaTEBXRCtoMp0EXklMTHRDUZUnlMdW0924uOtua9rptthZW8vrmw8SIpgOReAYZmxDMkPZ6h6fGNgT8kPV6zVJ4iBbqIdKjYQCRjMpMZk5ncbHldfQP5Bw+z80AVOw5UsaOokp0Hqli1vZhlawuabTugdy8v6IOB77XsE8js3UsnZU9CgS4iXSIqMoLs9Hiy0+OZfdy6qpo6dhVXeWFfdPS5khfWFlBRU9e4XXRUBNlpccGQT2BoxrHWvbpwFOgi0g3Ex0S12Kp3zlFcdSQY8pXBln0V24uqWuzCOdqqH9ok8LPT4ukV3TO6cBToItJtmRnpCTGkJ8QwdUjz0Td19Q0UlB5uDPmdB07chZOZHMvQDG8c/YDesWT27uU9knvRLzk2bEbiKNBFJCRFRUY0Tks8e2TzdYeO1LHrwCF2HKhkZ7ALZ/uBKv53475mV8oCmEFGQkww5GPJTO51LPCD4Z8WHx0S3TkKdBEJO3HRUYzOTGJ0ZtKn1lXX1rO3rJo9pYcpKD3M3lLv9Z6yw2zeV8Hrmwuprm1o9pmYqIjGgO8fDPyjLX3vfSxx0f7Hqf8ViIh0odhAZOPImZY45yg9VEtB6WEv6EsPs7esuvH9m1sPUFhRTcNxs6akxAUaA75Zt07wdZ/E2E4foaNAFxFpwsxIiY8mJT6asQOSW9ymtr6B/eXV7GnSuvfCv5r8g4d4d2dx431kj4qMMPolxZLZO5bPT85i3pRBHV57q4FuZkuAy4BC59zYFtbPAv4M7AwuWuac+2EH1igi0q0EIiPISokjKyXuhNtUVNc2tuwbu3WC3TxH6jtnUsS2tNCfBB4Elp5km5XOucs6pCIRkTCQGBsgMTbAGX0Tu+w7Wx2r45xbAZR0QS0iInIaOmrw5Zlmtt7MXjazMSfayMxuM7M8M8srKirqoK8WERHomEBfCwx2zuUADwB/OtGGzrnFzrlc51xuRkZGB3y1iIgcddqB7pwrd85VBl+/BATMLP20KxMRkXY57UA3s34WvITKzKYG91l8uvsVEZH2acuwxWeBWUC6meUD3wcCAM65R4DPA18xszrgMPAF59eNSkVEerBWA905d20r6x/EG9YoIiI+Co8pxkRERIEuIhIuFOgiImFCgS4iEiYU6CIiYUKBLiISJhToIiJhQoEuIhImFOgiImFCgS4iEiYU6CIiYUKBLiISJhToIiJhQoEuIhImFOgiImFCgS4iEiYU6CIiYUKBLiISJhToIiJhQoEuIhImFOgiImFCgS4iEiYU6CIiYUKBLiISJhToIiJhQoEuIhImFOgiImFCgS4iEiYU6CIiYUKBLiISJhToIiJhotVAN7MlZlZoZh+cYL2Z2a/MbJuZvW9mkzq+TBERaU1bWuhPAhefZP0lwIjg4zbg4dMvS0RE2qvVQHfOrQBKTrLJFcBS53kb6G1m/TuqQBERaZuO6EMfAOxu8j4/uOxTzOw2M8szs7yioqIO+GoRETmqS0+KOucWO+dynXO5GRkZXfnVIiJhL6oD9lEADGzyPiu4TEQk/DTUw5FKqKls8lzR5H1FK+srYeIX4cyFHV5aRwT6i8AiM/sdMA0oc87t7YD9ikhrnIOGOqg/AnU1LTzXQH0dNNR6yxpfBx8Nxz23+Dq4/6OvG/dVG/zu4Pujr1v6PA4iAxARgMho73Vk9HGvW1p2sm3bs33wdX3tcWF7fPi2EMY15c2X1R5q238bi4DoRIhJgOiEY8/xGd6jE7Qa6Gb2LDALSDezfOD7QADAOfcI8BJwKbANOAQs6JRKRbqz2mqoLjv2qCnzAvVEIVt35LjnlrY72fZNPofrnJ/JIoIBHGgSxse9jog6FpYRURDo5b2PiGoe4ND8H4L6Jq+PVEH9wSbLjjRff3SZq++cn7OpQFzz8I1JhMT+zd/HJH56m5beB3qBWefX3ESrge6cu7aV9Q7o+L8dRLqKc16rq7q8eShXl0F1aZOQbml98FF/pP3fGxGAqBgv8KJigmEZ03xZoBf06t1kmxiIij7uuek+jt8umk+1jI8P4aOt2GYhHICIyI4+0qenof64vwaOD/9WXtfVeD9XY/i20Hrubj9zO3VEl4uI/xrq4VAJHDoAh0tPHMwnCuWGupPvPzIGYpObPHpD78HB10nNl8cme620qNgThGxwWYQu1G6XiEjvEYj1u5JuS4Eu3VNDgxfCVQegqsgL6qqi4Psmr48uP1TCSbseAnHBoA2Gb1w6pA47LqRbCOajn1GISAhQoEvXcM5rHTcN5JOG9IET95n2SvECOT4D0kfA4Bne67h0iE/z1jdrLSd5LWORMKdAl9PnHBSsgcJNx8K4MaibhPSJ+pljkiA+3QvklGzImnxsJEBcurcuPiO4TZrXDyoin6JAl1NXexg2PAfvPgr7NhxbHog7FsKJ/aHfeK/l3FJIx6WpO0OkgyjQpf0Ofgx5T8DapXD4IPQZA5f9EobN9kI6Ot7vCkV6JAW6tI1zsOMNeHcxbHnZG6M86jKYervXh93F421F5NMU6HJyNRWw7llY/Rgc+MjrLjnnG5B7EyS3OAebiPhEgS4tO7DVa42ve9abh2LAZLjyURhzpTemWkS6HQW6HNNQD1v/F955FHYs9y5+GXMVTL3NG3kiIt2aAl28i3Le+w2sfhxKP4GkAXDe92DSlyBB0xyLhAoFek+2b4PXGt/wR6irhsFnw4U/hpFzIFL/a4iEGv3W9jT1tbDpRXj3MfhklTdmPOdamHor9B3jd3UichoU6D1FxX5Y8yTkLYHKfZAyBC76N5hwnXepvIiEPAV6OHMO8ld7o1U2/smbj3r4BTD1ARj+Wc32JxJmFOjhqLYaPnjeC/K967y5UqbeClNugbRhflcnIp1EgR5OSnd7l+SveQoOl0DGKJjzCxg/z5vAX0TCmgI91DkHO1cEL8l/yVs28lKYdjtkn6NL8kV6EAV6KCtYAy99y3uOS4Ozvu5dkt97oN+ViYgPFOih6FAJvPYDr2sloQ9c/gCM+ydNQyvSwynQQ0lDvTdl7Ws/8G5ofOZCOPefvVuniUiPp0APFQVr4K/fhD1rvSs6L70f+o72uyoR6UYU6N3d8d0rVz0O4z6vk50i8ikK9O5K3Ssi0k4K9O6oYA389Ruw5z11r4hImynQuxN1r4jIaVCgdwfqXhGRDqBA95u6V0SkgyjQ/VJV7LXI1y5V94qIdAgFeldT94qIdBIFeldS94qIdKI23eHAzC42sy1mts3M7m1h/XwzKzKzdcHHLR1fagirKoYXvwaPnQ/le7zulfn/ozAXkQ7VagvdzCKBh4ALgHxgtZm96Jz78LhNf++cW9QJNYauhnpY+xS89kN1r4hIp2tLl8tUYJtzbgeAmf0OuAI4PtClKXWviEgXa0ugDwB2N3mfD0xrYburzWwm8BFwl3Nu9/EbmNltwG0AgwYNan+1oaDZ6JW+Gr0iIl2mo+4S/Bcg2zk3Hvg78FRLGznnFjvncp1zuRkZGR301d1EQz3kLYEHJ8N7T3vdK4tWw/hrFOYi0iXa0kIvAJreAicruKyRc664ydvHgZ+dfmkhJH8NvKTuFRHxV1sCfTUwwsyG4AX5F4Drmm5gZv2dc3uDby8HNnVold2VuldEpBtpNdCdc3Vmtgh4BYgEljjnNprZD4E859yLwNfM7HKgDigB5ndizf6rLPRCfNWDGr0iIt2GOed8+eLc3FyXl5fny3efEufg47dg9ROw6S/QUAvDzoML/1XdKyLSZcxsjXMut6V1ulK0NdVlsP533gnPos0QmwxTb4XcmyB9hN/ViYg0UqCfyJ51kPcEbHgOag9B5iS44iEYcxVEx/ldnYjIpyjQm6o9DBtf8LpVCvIgqpd3knPKzZA50e/qREROSoEOULzd61J572moLoX0M+Dif4ecL0Cv3n5XJyLSJj030OvrYMtLXrfKjjcgIgo+cxlMuQWyz9bQQxEJOT0v0Mv3ePfsXPsUVOyFpCyYfR9MugES+/ldnYjIKesZgd7QADvf8PrGt7wMrgGGnw9zfgEjLoTInnEYRCS8hXeSHSqBdc94/eMl2yEuDWYsgskLIHWI39WJiHSo8At057ypa1c/ARuXQV01DJwOs+6F0VdAVIzfFYqIdIrwCfQjVbDhj16Q73sfohNgwnWQezP0G+t3dSIinS70A71wszdSZf3voKYc+oyBOT+H8fMgJtHv6kREukxoBnrdEdj0otc3/vFbEBkNo+d6FwANnKYhhyLSI4VeoH/0Cvx5IVQVQUo2fPYHMPF6iE/3uzIREV+FXqCnZEPWFK9vfNh5ENFRN10SEQltoRfoGSPh2mf9rkJEpNtR81ZEJEwo0EVEwoQCXUQkTCjQRUTChAJdRCRMKNBFRMKEAl1EJEwo0EVEwoQ55/z5YrMi4ONT/Hg6cKADywl1Oh7N6Xgco2PRXDgcj8HOuYyWVvgW6KfDzPKcc7l+19Fd6Hg0p+NxjI5Fc+F+PNTlIiISJhToIiJhIlQDfbHfBXQzOh7N6Xgco2PRXFgfj5DsQxcRkU8L1Ra6iIgcR4EuIhImQi7QzexiM9tiZtvM7F6/6/GTmQ00s+Vm9qGZbTSzO/2uyW9mFmlm75nZ//hdi9/MrLeZPWdmm81sk5md6XdNfjGzu4K/Ix+Y2bNmFut3TZ0hpALdzCKBh4BLgNHAtWY22t+qfFUHfMM5NxqYDizs4ccD4E5gk99FdBP/BfzNOfcZIIceelzMbADwNSDXOTcWiAS+4G9VnSOkAh2YCmxzzu1wzh0Bfgdc4XNNvnHO7XXOrQ2+rsD7hR3gb1X+MbMsYA7wuN+1+M3MkoGZwBMAzrkjzrlSX4vyVxTQy8yigDhgj8/1dIpQC/QBwO4m7/PpwQHWlJllAxOBd3wuxU+/BO4BGnyuozsYAhQBvw52QT1uZvF+F+UH51wB8B/AJ8BeoMw597/+VtU5Qi3QpQVmlgA8D3zdOVfudz1+MLPLgELn3Bq/a+kmooBJwMPOuYlAFdAjzzmZWQreX/JDgEwg3syu97eqzhFqgV4ADGzyPiu4rMcyswBemP/WObfM73p8dBZwuZntwuuKO8/Mnva3JF/lA/nOuaN/sT2HF/A90WeBnc65IudcLbAMmOFzTZ0i1AJ9NTDCzIaYWTTeiY0Xfa7JN2ZmeH2km5xzv/C7Hj85577tnMtyzmXj/X/xunMuLFthbeGc2wfsNrORwUXnAx/6WJKfPgGmm1lc8HfmfML0BHGU3wW0h3OuzswWAa/gnale4pzb6HNZfjoLuAHYYGbrgsu+45x7yb+SpBv5KvDbYONnB7DA53p84Zx7x8yeA9bijQx7jzCdAkCX/ouIhIlQ63IREZETUKCLiIQJBbqISJhQoIuIhAkFuohImFCgi4iECQW6iEiY+P+j7fAYrrMYegAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy: 0.9628\n"
     ]
    }
   ],
   "source": [
    "# 导入必要的库\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "from torch.utils.data import DataLoader\n",
    "from torchvision import datasets, transforms\n",
    "from tqdm import * # tqdm用于显示进度条并评估任务时间开销\n",
    "import numpy as np\n",
    "import sys\n",
    "\n",
    "# 设置随机种子\n",
    "torch.manual_seed(0)\n",
    "\n",
    "# 定义模型、优化器、损失函数\n",
    "model = LeNet()\n",
    "optimizer = optim.SGD(model.parameters(), lr=0.02)\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "\n",
    "# 设置数据变换和数据加载器\n",
    "transform = transforms.Compose([\n",
    "    transforms.ToTensor(),  # 将数据转换为张量\n",
    "])\n",
    "\n",
    "# 加载训练数据\n",
    "train_dataset = datasets.MNIST(root='../data/mnist/', train=True, download=True, transform=transform)\n",
    "# 实例化训练数据加载器\n",
    "train_loader = DataLoader(train_dataset, batch_size=256, shuffle=True)\n",
    "# 加载测试数据\n",
    "test_dataset = datasets.MNIST(root='../data/mnist/', train=False, download=True, transform=transform)\n",
    "# 实例化测试数据加载器\n",
    "test_loader = DataLoader(test_dataset, batch_size=256, shuffle=False)\n",
    "\n",
    "# 设置epoch数并开始训练\n",
    "num_epochs = 10  # 设置epoch数\n",
    "loss_history = []  # 创建损失历史记录列表\n",
    "acc_history = []   # 创建准确率历史记录列表\n",
    "\n",
    "# tqdm用于显示进度条并评估任务时间开销\n",
    "for epoch in tqdm(range(num_epochs), file=sys.stdout):\n",
    "    # 记录损失和预测正确数\n",
    "    total_loss = 0\n",
    "    total_correct = 0\n",
    "    \n",
    "    # 批量训练\n",
    "    model.train()\n",
    "    for inputs, labels in train_loader:\n",
    "\n",
    "        # 预测、损失函数、反向传播\n",
    "        optimizer.zero_grad()\n",
    "        outputs = model(inputs)\n",
    "        loss = criterion(outputs, labels)\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        \n",
    "        # 记录训练集loss\n",
    "        total_loss += loss.item()\n",
    "    \n",
    "    # 测试模型，不计算梯度\n",
    "    model.eval()\n",
    "    with torch.no_grad():\n",
    "        for inputs, labels in test_loader:\n",
    "\n",
    "            # 预测\n",
    "            outputs = model(inputs)\n",
    "            # 记录测试集预测正确数\n",
    "            total_correct += (outputs.argmax(1) == labels).sum().item()\n",
    "        \n",
    "    # 记录训练集损失和测试集准确率\n",
    "    loss_history.append(np.log10(total_loss))  # 将损失加入损失历史记录列表，由于数值有时较大，这里取对数\n",
    "    acc_history.append(total_correct / len(test_dataset))# 将准确率加入准确率历史记录列表\n",
    "    \n",
    "    # 打印中间值\n",
    "    if epoch % 2 == 0:\n",
    "        tqdm.write(\"Epoch: {0} Loss: {1} Acc: {2}\".format(epoch, loss_history[-1], acc_history[-1]))\n",
    "\n",
    "# 使用Matplotlib绘制损失和准确率的曲线图\n",
    "import matplotlib.pyplot as plt\n",
    "plt.plot(loss_history, label='loss')\n",
    "plt.plot(acc_history, label='accuracy')\n",
    "plt.legend()\n",
    "plt.show()\n",
    "\n",
    "# 输出准确率\n",
    "print(\"Accuracy:\", acc_history[-1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "af348896",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
